#include "kdtree.h"
#include "utils.h"
#include "timer.h"

#include "raytracer.h"
#include "glintersectiondebug.h"

#include <algorithm>
#include <iostream>
#include <limits>

using namespace utils;
namespace rendersystem {

// Tolerance constant for intersection tests.
// NOTE(review): in the code visible in this file it is referenced only from a
// commented-out expression in KdTree::traverse — confirm whether it is still needed.
#define EPS_INTERSECTION 0.001

// Tests theRay against every triangle stored in this node.
//
// Each successful triangle test fills a scratch intersection; when that
// scratch value compares less than theIntersection (Intersection::operator<,
// presumably ordering by ray parameter — confirm), it replaces theIntersection.
//
// tracer          : forwarded unchanged to RTTriangle::intersect.
// theRay          : the ray being tested.
// theIntersection : in/out; current best intersection, updated in place.
//
// Returns theIntersection.valid after the scan. Note that this is the flag of
// the in/out argument, so it is also true when the argument was already valid
// on entry and no triangle of this node improved it.
bool KdTreeNode::intersectTriangles(RayTracer *tracer,
                                    const RayTracer::Ray& theRay,
                                    RayTracer::Intersection &theIntersection) {
    RayTracer::Intersection candidate;

    //@BEGIN
    typedef std::vector<RayTracer::RTTriangle*>::size_type TriangleIndex;
    for (TriangleIndex i = 0; i < triangles.size(); ++i) {
        RayTracer::RTTriangle *triangle = triangles[i];
        // The scratch intersection is reused across triangles, as before.
        if (triangle->intersect(tracer, theRay, candidate)
                && candidate < theIntersection) {
            theIntersection = candidate;
        }
    }
    //@END
    return theIntersection.valid;
}

// Builds the kd-tree for all triangles held by the given ray tracer.
//
// tracer           : scene holder; its mTriangles and mVertices are read.
// theDepthLimit    : stored as depthLimit, consulted by subdivide() when
//                    deciding whether to recurse into children.
// theTriangleLimit : stored as triangleLimit; a node holding more triangles
//                    than this is subdivided.
//
// The root node receives every scene triangle and a bounding box grown from
// the three vertex positions of each triangle. Root statistics and a final
// completion message are written to std::cerr.
// NOTE(review): the root is subdivided whenever it exceeds triangleLimit,
// without consulting depthLimit — confirm this is intended for small limits.
KdTree::KdTree(RayTracer *tracer, int theDepthLimit, int theTriangleLimit) :
    sceneHolder(tracer), depthLimit(theDepthLimit), triangleLimit(
        theTriangleLimit) {
    typedef std::vector<RayTracer::RTTriangle *> TriangleList;

    RayTracer::RTBbox sceneBounds;
    rootNode = new KdTreeNode;

    // Hand every triangle to the root and grow the scene box around its corners.
    TriangleList &sceneTriangles = sceneHolder->mTriangles;
    for (TriangleList::size_type i = 0; i < sceneTriangles.size(); ++i) {
        RayTracer::RTTriangle *triangle = sceneTriangles[i];
        rootNode->addTriangle(triangle);
        for (int corner = 0; corner < 3; ++corner) {
            sceneBounds.add(
                sceneHolder->mVertices[triangle->vertices[corner]]->position);
        }
    }
    rootNode->setBbox(sceneBounds);
    rootNode->updateTriangleCount();

    // Report the root node on the error stream.
    glm::vec3 lowerCorner;
    glm::vec3 upperCorner;
    sceneBounds.getBounds(lowerCorner, upperCorner);
    std::cerr << "Root node : " << rootNode->getTriangleCount() << std::endl;
    std::cerr << "\tBbox : " << lowerCorner << " // " << upperCorner << std::endl;

    // Only split the root when it holds more triangles than allowed per node.
    if (rootNode->getTriangleCount() > triangleLimit) {
        subdivide(rootNode);
    }

    std::cerr << "fini la construction" << std::endl;
}

// Destroys the tree by deleting the root node allocated in the constructor.
// NOTE(review): child nodes are presumably released by KdTreeNode's own
// destructor — confirm in kdtree.h.
KdTree::~KdTree()
{
    delete this->rootNode;
}

// Subdivision method: 0 = midpoint, 1 = median.
// NOTE(review): KdTree::selectSplit as written in this file always splits at
// the midpoint of the bounding box and does not read this variable — confirm
// whether another translation unit relies on it.
int MODE_DE_DIVISION_KDTREE = 0;

// Selects the split axis and the split position for a node.
//
// theNode : node whose bounding box drives the choice.
// axis    : out; index (0, 1 or 2) of the axis along which the box is largest.
//           On ties, axis 2 wins over axes 0 and 1, and axis 0 wins over axis 1.
// pos     : out; midpoint of the bounding box along the chosen axis.
void KdTree::selectSplit(KdTreeNode *theNode, int &axis, float &pos) {
    glm::vec3 boxSize;
    theNode->getBbox().getSize(boxSize);

    // Pick the larger of the first two axes, then let the third override it
    // when it is at least as large.
    int widest = (boxSize[0] >= boxSize[1]) ? 0 : 1;
    if (boxSize[2] >= boxSize[widest]) {
        widest = 2;
    }
    axis = widest;

    glm::vec3 lowerCorner;
    glm::vec3 upperCorner;
    theNode->getBbox().getBounds(lowerCorner, upperCorner);

    // Split in the middle of the box along the chosen axis.
    pos = (lowerCorner[axis] + upperCorner[axis]) / 2.0;
}
// Splits a node into two children.
//
// The split axis and position come from selectSplit() and are stored on the
// node. The node's bounding box is cut at the split position into a "left"
// box (lower side) and a "right" box (upper side). Every triangle of the node
// is handed to the left child, the right child, or both, based on the
// triangle's min/max along the split axis; the node's own triangle list is
// then cleared and the node is marked as a non-leaf.
//
// Children are subdivided recursively while the parent's depth is below
// depthLimit - 1 and the child holds more than triangleLimit triangles.
void KdTree::subdivide(KdTreeNode *theNode) {
    typedef std::vector<RayTracer::RTTriangle *> TriangleList;

    // Choose the splitting plane and record it on the node.
    int splitAxis;
    float splitPos;
    selectSplit(theNode, splitAxis, splitPos);
    theNode->setAxis(splitAxis);
    theNode->setSplit(splitPos);

    // Derive the corners of the two child boxes from the parent box.
    glm::vec3 parentMin;
    glm::vec3 parentMax;
    theNode->getBbox().getBounds(parentMin, parentMax);
    glm::vec3 lowerHalfMax = parentMax;
    lowerHalfMax[splitAxis] = splitPos;
    glm::vec3 upperHalfMin = parentMin;
    upperHalfMin[splitAxis] = splitPos;

    // Allocate both children one level deeper; each reserves room for the
    // parent's full triangle count.
    KdTreeNode *lowerChild = new KdTreeNode(theNode->getDepth() + 1);
    lowerChild->reserveTriangles(theNode->getTriangleCount());
    KdTreeNode *upperChild = new KdTreeNode(theNode->getDepth() + 1);
    upperChild->reserveTriangles(theNode->getTriangleCount());

    // Distribute the parent's triangles.
    TriangleList &parentTriangles = theNode->getTriangles();
    for (TriangleList::size_type i = 0; i < parentTriangles.size(); ++i) {
        RayTracer::RTTriangle *triangle = parentTriangles[i];

        // Entirely on the lower side: left child only.
        const bool onlyLower = triangle->max[splitAxis] <= splitPos;
        // Otherwise entirely on the upper side: right child only.
        const bool onlyUpper = !onlyLower
                && triangle->min[splitAxis] >= splitPos;

        // A triangle that is neither goes to both children, left first.
        if (!onlyUpper) {
            lowerChild->addTriangle(triangle);
        }
        if (!onlyLower) {
            upperChild->addTriangle(triangle);
        }
    }
    parentTriangles.clear();

    // Finalize the children: bounding box, then cached triangle count.
    RayTracer::RTBbox lowerBox(parentMin, lowerHalfMax);
    lowerChild->setBbox(lowerBox);
    lowerChild->updateTriangleCount();

    RayTracer::RTBbox upperBox(upperHalfMin, parentMax);
    upperChild->setBbox(upperBox);
    upperChild->updateTriangleCount();

    // Attach the children; the node is now an interior node.
    theNode->setLeaf(false);
    theNode->setLeft(lowerChild);
    theNode->setRight(upperChild);

    // Stop here once the depth budget is exhausted.
    if (!(theNode->getDepth() < depthLimit - 1)) {
        return;
    }

    // Recurse into any child that still holds too many triangles.
    if (lowerChild->getTriangleCount() > triangleLimit) {
        subdivide(lowerChild);
    }
    if (upperChild->getTriangleCount() > triangleLimit) {
        subdivide(upperChild);
    }
}

// Intersects a ray with the scene through the kd-tree.
//
// tracer          : forwarded to traverse().
// theRay          : the ray being tested.
// theIntersection : in/out; its t on entry is used as the upper bound of the
//                   ray interval, and it receives the hit when one is found.
//
// The ray interval [0, theIntersection.t] is first clipped against the root
// bounding box; if the box is missed, false is returned immediately. The root
// is then pushed on theNodeStack and nodes are popped and handed to traverse()
// until it reports a hit or the stack runs empty. Every visited node's box is
// reported to RayTracer::intersectionDebug when that pointer is set, with a
// flag telling whether the hit was found in that node. All StackNode objects
// allocated during the traversal are deleted before returning.
//
// Returns true as soon as traverse() reports a hit, false otherwise.
bool KdTree::intersect(RayTracer *tracer, const RayTracer::Ray& theRay,
                       RayTracer::Intersection &theIntersection) {

    //@BEGIN
    // Clip the ray interval to the scene bounding box.
    float tEnter = 0;
    float tExit = theIntersection.t;
    if (!rootNode->getBbox().intersect(theRay, tEnter, tExit)) {
        return false; // the ray misses the scene box entirely
    }

    // Seed the traversal stack with the root node.
    theNodeStack.push(new StackNode(rootNode, tEnter, tExit));

    bool found = false;
    while (!found && !theNodeStack.empty()) {
        StackNode *entry = theNodeStack.top();
        theNodeStack.pop();

        found = traverse(entry, tracer, theRay, theIntersection);

        if (found) {
            // A hit ends the traversal: discard everything still pending.
            while (!theNodeStack.empty()) {
                StackNode *pending = theNodeStack.top();
                theNodeStack.pop();
                delete pending;
            }
        }

        if (RayTracer::intersectionDebug) {
            RayTracer::intersectionDebug->pushBox(entry->node->getBbox(),
                                                  found);
        }
        delete entry;
    }
    return found;
    //@END
    return false;
}

// Processes one entry of the traversal stack.
//
// currentNode     : stack entry holding a tree node and the ray interval
//                   [tmin, tmax] associated with it.
// tracer, theRay  : forwarded to the triangle tests.
// theIntersection : in/out intersection record.
//
// Leaf node: theIntersection.t is set to the entry's tmax and the node's
// triangles are tested; the result of intersectTriangles() is returned.
//
// Interior node: the ray parameter at which the ray meets the split plane is
// compared with [tmin, tmax] and one or both children are pushed on
// theNodeStack. When both are pushed, the child that the ray reaches second is
// pushed first so that the nearer child is popped first. A ray whose direction
// component along the split axis is zero visits only the child on the side of
// its origin (left when the origin is at or below the split). Interior nodes
// always return false.
bool KdTree::traverse(StackNode *currentNode, RayTracer *tracer,
                      const RayTracer::Ray& theRay,
                      RayTracer::Intersection &theIntersection) {

    //@BEGIN
    KdTreeNode *const node = currentNode->node;

    if (node->isleaf()) {
        // Bound the search by the end of this cell's ray interval.
        theIntersection.t = currentNode->tmax; //+EPS_INTERSECTION;
        return node->intersectTriangles(tracer, theRay, theIntersection);
    }

    // Interior node: build the stack entries for the children that the ray
    // crosses and push them in the right order.
    const int axis = node->getAxis();
    const float plane = node->getSplit();
    const float segStart = currentNode->tmin;
    const float segEnd = currentNode->tmax;

    if (theRay.mDirection[axis] > 0) {
        // Ray travels from the left child towards the right child.
        const float tPlane = (plane - theRay.mOrigin[axis])
                * theRay.mInvDirection[axis];
        if (tPlane <= segStart) {
            // Plane is crossed before the interval: right child only.
            theNodeStack.push(new StackNode(node->getRight(), segStart, segEnd));
        } else if (tPlane >= segEnd) {
            // Plane is crossed after the interval: left child only.
            theNodeStack.push(new StackNode(node->getLeft(), segStart, segEnd));
        } else {
            // Far (right) child first, so the near (left) child is popped first.
            theNodeStack.push(new StackNode(node->getRight(), tPlane, segEnd));
            theNodeStack.push(new StackNode(node->getLeft(), segStart, tPlane));
        }
    } else if (theRay.mDirection[axis] < 0) {
        // Ray travels from the right child towards the left child.
        const float tPlane = (plane - theRay.mOrigin[axis])
                * theRay.mInvDirection[axis];
        if (tPlane >= segEnd) {
            // Plane is crossed after the interval: right child only.
            theNodeStack.push(new StackNode(node->getRight(), segStart, segEnd));
        } else if (tPlane <= segStart) {
            // Plane is crossed before the interval: left child only.
            theNodeStack.push(new StackNode(node->getLeft(), segStart, segEnd));
        } else {
            // Far (left) child first, so the near (right) child is popped first.
            theNodeStack.push(new StackNode(node->getLeft(), tPlane, segEnd));
            theNodeStack.push(new StackNode(node->getRight(), segStart, tPlane));
        }
    } else {
        // Direction component along the split axis is zero: the ray never
        // crosses the plane, so only the origin's side is visited.
        if (theRay.mOrigin[axis] <= plane) {
            theNodeStack.push(new StackNode(node->getLeft(), segStart, segEnd));
        } else {
            theNodeStack.push(new StackNode(node->getRight(), segStart, segEnd));
        }
    }
    //@END
    return false;
}

} // namespace;
